
#############################################
### Search for heterogeneous effects      ###
#############################################


library(dplyr)
library(estimatr)
library(ggplot2)
library(tidyr)
library(mediation)
library(ri2)
options(java.parameters = "-Xmx6g")
library(bartMachine)

#set directory to Data/folder

# load and recode (copied from results files) ####
# Read the combined analysis file (path is relative to the working directory).
dat <- read.csv("combined_dataset.csv")

# Derive match status, sex and age group at random assignment in one pass.
dat <- dat %>%
  mutate(
    # matched = 1 when `posterior` is non-missing, 0 otherwise
    matched = as.numeric(!is.na(posterior)),
    # The baseline male indicators (child or adult) are checked first; the
    # survey gender is only used when both baseline indicators are missing.
    sex = case_when(
      x_f_ch_male == 1 | x_f_ad_male == 1 ~ "M",
      x_f_ch_male == 0 | x_f_ad_male == 0 ~ "F",
      is.na(x_f_ch_male) & is.na(x_f_ad_male) & f_svy_gender == "M" ~ "M",
      is.na(x_f_ch_male) & is.na(x_f_ad_male) & f_svy_gender == "F" ~ "F"
    ),
    # Age at random assignment (ra_year minus imputed birth year) splits the
    # children at 13; rows that fall through both comparisons are labelled
    # "adult" when the 2007 survey sample code is "AD" or "ES".
    age_group = case_when(
      (ra_year - f_svy_yob_imp_ytgc) >= 13 ~ "old_kid",
      (ra_year - f_svy_yob_imp_ytgc) < 13 ~ "young_kid",
      f_svy_sample2007 %in% c("AD", "ES") ~ "adult"
    )
  )

# Keep only rows that have both a randomization group and an age group
# (the dropped rows are mostly people who appear only in interim/roster data).
dat <- dat %>%
  filter(!is.na(ra_group), !is.na(age_group))

# Treat incorrect matches as non-matches: a non-adult with any recorded
# pre-treatment turnout is flagged as a bad match and its match status is
# reset to 0.
dat <- dat %>%
  mutate(
    badmatch = as.numeric(!is.na(r_pretreatturnout) & age_group != "adult"),
    matched = ifelse(badmatch == 1, 0, matched)
  )

# Turnout is 0 for anyone who did not match to the voter file; any remaining
# missing post-treatment turnout is also set to 0.
dat <- dat %>%
  mutate(
    r_postturnout = case_when(
      matched == 0 | is.na(r_postturnout) ~ 0,
      TRUE ~ r_postturnout
    ),
    r_pretreatturnout = case_when(
      matched == 0 ~ 0,
      TRUE ~ r_pretreatturnout
    )
  )

# Build the ever-voted indicators and blank out every turnout measure for bad
# matches. The steps run in order inside one mutate(): the indicators are
# computed from the turnout counts first, then the counts are overridden.
dat <- dat %>%
  mutate(
    # 1 if any post-/pre-treatment turnout is recorded, 0 otherwise
    # (bad matches are always 0)
    evervoted_post = case_when(
      badmatch == 1 ~ 0,
      r_postturnout > 0 ~ 1,
      TRUE ~ 0
    ),
    evervoted_pre = case_when(
      badmatch == 1 ~ 0,
      r_pretreatturnout > 0 ~ 1,
      TRUE ~ 0
    ),
    # turnout counts are forced to 0 for bad matches
    r_postturnout = case_when(
      badmatch == 1 ~ 0,
      TRUE ~ r_postturnout
    ),
    r_pretreatturnout = case_when(
      badmatch == 1 ~ 0,
      TRUE ~ r_pretreatturnout
    ),
    # post-registration turnout is set to missing (not 0) for bad matches
    r_postregturnout = case_when(
      badmatch == 1 ~ NA_real_,
      TRUE ~ r_postregturnout
    )
  )


# Covariate names for the control models, one group of lines per name prefix.
covs_ad <- c(
  # x_f_ad_*
  "x_f_ad_36_40", "x_f_ad_41_45", "x_f_ad_46_50",
  "x_f_ad_edged", "x_f_ad_edgradhs", "x_f_ad_edgradhs_miss",
  "x_f_ad_edinsch", "x_f_ad_ethn_hisp", "x_f_ad_le_35", "x_f_ad_male",
  "x_f_ad_nevmarr", "x_f_ad_parentu18", "x_f_ad_race_black",
  "x_f_ad_race_other", "x_f_ad_working",
  # x_f_hh_*
  "x_f_hh_afdc", "x_f_hh_car", "x_f_hh_disabl", "x_f_hh_noteens",
  "x_f_hh_size2", "x_f_hh_size3", "x_f_hh_size4", "x_f_hh_victim",
  # x_f_hood_*
  "x_f_hood_5y", "x_f_hood_chat", "x_f_hood_nbrkid", "x_f_hood_nofamily",
  "x_f_hood_nofriend", "x_f_hood_unsafenit", "x_f_hood_verydissat",
  # x_f_hous_*
  "x_f_hous_fndapt", "x_f_hous_mov3tm", "x_f_hous_movdrgs",
  "x_f_hous_movschl", "x_f_hous_sec8bef",
  # x_f_site_*
  "x_f_site_balt", "x_f_site_bos", "x_f_site_chi", "x_f_site_la"
)



# Recode race, rename the post-treatment poverty measure, and build ordered
# education factors for adults (ad_educ) and youth (yt_educ).
dat <- dat %>%
  mutate(
    # NOTE(review): the final TRUE branch labels every row not caught above
    # as "white", including rows where the adult race indicators are missing.
    race = case_when(
      x_f_ad_race_black == 1 ~ "black",
      x_f_ad_ethn_hisp == 1 ~ "hisp",
      x_f_ad_race_other == 1 & x_f_ad_ethn_hisp == 0 ~ "other",
      TRUE ~ "white"
    )
  ) %>%
  rename(local_poverty_posttreat = f_c9010t_perpov_dw) %>%
  mutate(
    # rows matching none of the conditions stay NA
    ad_educ = factor(
      case_when(
        hed2 == 4 ~ "no hs",
        (hed2 %in% c(1, 2)) & hed3 == 5 ~ "hs",
        (hed4 %in% c(1, 2)) ~ "aa",
        (hed4 %in% c(3, 4)) ~ "ba+"
      ),
      levels = c("no hs", "hs", "aa", "ba+"),
      ordered = TRUE
    ),
    # the yed* items are checked before the hho* items; first match wins
    yt_educ = factor(
      case_when(
        yed3c == 5 & yed1 == 5 ~ "no hs",
        yed3c == 1 & (yed3a <= 12) ~ "hs",
        yed3a > 12 ~ "more than hs",
        hho3 %in% c(1, 2) ~ "hs",
        hho3 == 3 ~ "no hs",
        hho4 == 1 ~ "more than hs"
      ),
      levels = c("no hs", "hs", "more than hs"),
      ordered = TRUE
    )
  )

# Family-level mean of a22, attached to every member of the family.
# Families whose a22 values are all missing get NaN (mean of an empty set).
#
# ungroup() is required: without it `dat` stays grouped by mto_pseudo_famid,
# and every later dplyr::select() silently re-adds the grouping column, which
# would put the family id into the BART predictor matrices built below.
dat <- dat %>%
  group_by(mto_pseudo_famid) %>%
  mutate(fam_moves = mean(a22, na.rm = TRUE)) %>%
  ungroup()







# adults ####

# Adult sample: keep the experimental and control arms only (drop
# "section 8") and code treatment as 1 = experimental, 0 = control.
datb <- dat %>%
  filter(ra_group_factor != "section 8", age_group == "adult") %>%
  mutate(
    treat = case_when(
      ra_group_factor == "experimental" ~ 1,
      ra_group_factor == "control" ~ 0
    )
  )

# Candidate moderators: treatment plus the pretreatment covariates.
xmat <- datb %>%
  dplyr::select(
    treat,
    x_f_ad_edged, x_f_ad_edgradhs, x_f_ad_edgradhs_miss, x_f_ad_edinsch,
    x_f_ad_36_40, x_f_ad_41_45, x_f_ad_46_50,
    x_f_ad_ethn_hisp, x_f_ad_le_35,
    x_f_ad_male, x_f_ad_nevmarr, x_f_ad_parentu18,
    x_f_ad_race_black, x_f_ad_race_other, x_f_ad_working,
    x_f_hh_afdc, x_f_hh_car, x_f_hh_disabl, x_f_hh_noteens,
    x_f_hh_size2, x_f_hh_size3, x_f_hh_size4, x_f_hh_victim,
    x_f_hood_5y, x_f_hood_chat, x_f_hood_nbrkid, x_f_hood_nofamily,
    x_f_hood_nofriend, x_f_hood_unsafenit, x_f_hood_verydissat,
    x_f_hous_fndapt, x_f_hous_mov3tm, x_f_hous_movdrgs,
    x_f_hous_movschl, x_f_hous_sec8bef,
    x_f_site_balt, x_f_site_bos, x_f_site_chi, x_f_site_la
  )

# Fit BART on post-treatment turnout to search for heterogeneous effects.
# NOTE(review): no seed is passed, so fits will differ between runs.
b.fit <- bartMachine(
  X = xmat,
  y = datb$r_postturnout,
  num_trees = 50
)
summary(b.fit)
plot_y_vs_yhat(b.fit, prediction_intervals = TRUE)
# poor coverage: don't proceed with analyzing results


# teens at random assignment ####

# Teen sample ("old_kid"): experimental vs. control only, with treatment
# coded as 1 = experimental, 0 = control.
datb <- dat %>%
  filter(ra_group_factor != "section 8", age_group == "old_kid") %>%
  mutate(
    treat = case_when(
      ra_group_factor == "experimental" ~ 1,
      ra_group_factor == "control" ~ 0
    )
  )

# Candidate moderators: treatment, the adult/household/neighborhood/housing/
# site covariates, and the child-level (x_f_c1_*, x_f_ch_*) covariates.
xmat <- datb %>%
  dplyr::select(
    treat,
    x_f_ad_edged, x_f_ad_edgradhs, x_f_ad_edgradhs_miss, x_f_ad_edinsch,
    x_f_ad_36_40, x_f_ad_41_45, x_f_ad_46_50,
    x_f_ad_ethn_hisp, x_f_ad_le_35,
    x_f_ad_male, x_f_ad_nevmarr, x_f_ad_parentu18,
    x_f_ad_race_black, x_f_ad_race_other, x_f_ad_working,
    x_f_hh_afdc, x_f_hh_car, x_f_hh_disabl, x_f_hh_noteens,
    x_f_hh_size2, x_f_hh_size3, x_f_hh_size4, x_f_hh_victim,
    x_f_hood_5y, x_f_hood_chat, x_f_hood_nbrkid, x_f_hood_nofamily,
    x_f_hood_nofriend, x_f_hood_unsafenit, x_f_hood_verydissat,
    x_f_hous_fndapt, x_f_hous_mov3tm, x_f_hous_movdrgs,
    x_f_hous_movschl, x_f_hous_sec8bef,
    x_f_site_balt, x_f_site_bos, x_f_site_chi, x_f_site_la,
    x_f_c1_behprb, x_f_c1_behprb_miss, x_f_c1_expel, x_f_c1_expel_miss,
    x_f_c1_gifted, x_f_c1_gifted_miss, x_f_c1_lrnprb, x_f_c1_lrnprb_miss,
    x_f_c1_schcll,
    x_f_ch_age10, x_f_ch_age11, x_f_ch_age12, x_f_ch_age13,
    x_f_ch_age14, x_f_ch_age15, x_f_ch_age16, x_f_ch_age17,
    x_f_ch_age18, x_f_ch_age19, x_f_ch_age20,
    x_f_ch_male, x_f_ch_schplay, x_f_ch_schplay_miss, x_f_ch_specmed
  )

# Fit BART on the ever-voted (post-treatment) indicator.
b.fit <- bartMachine(
  X = xmat,
  y = datb$evervoted_post,
  num_trees = 50
)

# Fit diagnostics for the single model, then for the first of 10 model runs.
summary(b.fit)
plot_y_vs_yhat(b.fit, prediction_intervals = TRUE)
b.fit.arr <- bartMachineArr(b.fit, R = 10)
summary(b.fit.arr[[1]])
plot_y_vs_yhat(b.fit.arr[[1]], prediction_intervals = TRUE)

# Variable importance for each of the 10 runs of the model;
# record the variables which appear in >1 model's top 3.
for (run in seq_len(10)) {
  investigate_var_importance(b.fit.arr[[run]], num_replicates_for_avg = 10)
}

# Reduced moderator set: treatment plus the variables that were in the top 3
# of more than one model.
xmat <- datb %>%
  dplyr::select(
    treat,
    x_f_ch_male, x_f_ch_specmed,
    x_f_site_la, x_f_site_balt, x_f_site_chi,
    x_f_ad_male
  )

# Refit on the reduced set and build the array of 10 models used for
# prediction.
b.fit <- bartMachine(
  X = xmat,
  y = datb$evervoted_post,
  num_trees = 50
)
b.fit.arr <- bartMachineArr(b.fit, R = 10)

# Calculate the treatment effects for each subgroup of a binary moderator.
#
# Uses the global predictor matrix `xmat` and the global fitted model array
# `b.fit.arr`. For each level (0/1) of `var`, every row of `xmat` is set to
# that level and predicted once under control (treat = 0) and once under
# treatment (treat = 1); the subgroup effect is the mean difference.
#
# var: name (string) of a column of `xmat` to toggle between 0 and 1.
# Returns a named character vector: varname, ate_1, ate_0, diff
# (the numeric values are coerced to character by c()).
hte_cal <- function(var){

  # Predict on a copy of xmat with treat and `var` overwritten, then print
  # the progress label for that step.
  counterfactual_pred <- function(treat_value, attr_value, label) {
    scenario <- xmat
    scenario[, "treat"] <- treat_value
    scenario[, var] <- attr_value
    fitted <- predict_bartMachineArr(b.fit.arr, new_data = scenario)
    print(label)
    fitted
  }

  # attribute = 0: control first, then treatment
  control_0 <- counterfactual_pred(0, 0, "b1 1 done")
  treated_0 <- counterfactual_pred(1, 0, "b2 1 done")
  ate_0 <- mean(treated_0 - control_0)

  # attribute = 1: control first, then treatment
  control_1 <- counterfactual_pred(0, 1, "b1 2 done")
  treated_1 <- counterfactual_pred(1, 1, "b2 2 done")
  ate_1 <- mean(treated_1 - control_1)

  # difference in subgroup effects
  c("varname" = var, "ate_1" = ate_1, "ate_0" = ate_0,
    "diff" = ate_1 - ate_0)
}

# Run hte_cal() on all vars with >1 appearance and stack the results into
# `tab1`: one row per variable, columns varname / ate_1 / ate_0 / diff.
htevars <- c("x_f_ch_male",
             "x_f_ch_specmed",
             "x_f_site_la",
             "x_f_site_balt",
             "x_f_site_chi",
             "x_f_ad_male")

# The previous loop ended each iteration with paste("done", var). `var` is
# not defined at top level, so it resolved to the stats::var function, which
# paste() cannot coerce to character; the loop stopped with an error after
# the first iteration. Progress is now reported with message() using the
# actual variable name, and the table is built in one rbind instead of being
# grown row by row.
tab1 <- do.call(rbind, lapply(htevars, function(v) {
  est <- hte_cal(v)
  message("done ", v)
  est
}))


# young kids at random assignment ####

# Young-kid sample: experimental vs. control only, with treatment coded as
# 1 = experimental, 0 = control.
datb <- dat %>%
  filter(ra_group_factor != "section 8", age_group == "young_kid") %>%
  mutate(
    treat = case_when(
      ra_group_factor == "experimental" ~ 1,
      ra_group_factor == "control" ~ 0
    )
  )

# first focus on post-treatment turnout

# Candidate moderators: treatment, the adult/household/neighborhood/housing/
# site covariates, and the child-level (x_f_c2_*, x_f_ch_*) covariates.
xmat <- datb %>%
  dplyr::select(
    treat,
    x_f_ad_edged, x_f_ad_edgradhs, x_f_ad_edgradhs_miss, x_f_ad_edinsch,
    x_f_ad_36_40, x_f_ad_41_45, x_f_ad_46_50,
    x_f_ad_ethn_hisp, x_f_ad_le_35,
    x_f_ad_male, x_f_ad_nevmarr, x_f_ad_parentu18,
    x_f_ad_race_black, x_f_ad_race_other, x_f_ad_working,
    x_f_hh_afdc, x_f_hh_car, x_f_hh_disabl, x_f_hh_noteens,
    x_f_hh_size2, x_f_hh_size3, x_f_hh_size4, x_f_hh_victim,
    x_f_hood_5y, x_f_hood_chat, x_f_hood_nbrkid, x_f_hood_nofamily,
    x_f_hood_nofriend, x_f_hood_unsafenit, x_f_hood_verydissat,
    x_f_hous_fndapt, x_f_hous_mov3tm, x_f_hous_movdrgs,
    x_f_hous_movschl, x_f_hous_sec8bef,
    x_f_site_balt, x_f_site_bos, x_f_site_chi, x_f_site_la,
    x_f_c2_hosp, x_f_c2_hosp_miss, x_f_c2_lowbw, x_f_c2_lowbw_miss,
    x_f_c2_read, x_f_c2_read_miss,
    x_f_ch_age10, x_f_ch_age11, x_f_ch_age12, x_f_ch_age13,
    x_f_ch_age14, x_f_ch_age15, x_f_ch_age16, x_f_ch_age17,
    x_f_ch_age18, x_f_ch_age19, x_f_ch_age20,
    x_f_ch_male, x_f_ch_schplay, x_f_ch_schplay_miss, x_f_ch_specmed
  )

# Fit BART on post-treatment turnout.
b.fit <- bartMachine(
  X = xmat,
  y = datb$r_postturnout,
  num_trees = 50
)

# Build 10 model runs and inspect the fit of the first one.
b.fit.arr <- bartMachineArr(b.fit, R = 10)
summary(b.fit.arr[[1]])
plot_y_vs_yhat(b.fit.arr[[1]], prediction_intervals = TRUE)


# Variable importance for each model run;
# note any vars that appear in any top 10.
for (run in seq_len(10)) {
  investigate_var_importance(b.fit.arr[[run]], num_replicates_for_avg = 10)
}

# Reduced moderator set: treatment plus every variable that was a top-10
# predictor in at least one model run.
# Each "#" after a name marks one model in which it was a top-10 predictor.
# NOTE(review): x_f_c2_read_miss is listed twice (with separate tallies);
# select() keeps a single copy of the column.
xmat <- datb %>%
  dplyr::select(
    treat,
    x_f_ch_age19,         #######
    x_f_ad_le_35,         #######
    x_f_c2_read_miss,     ####
    x_f_ad_edgradhs_miss, ##
    x_f_c2_hosp,          ####
    x_f_site_balt,        #######
    x_f_ch_age10,         #####
    x_f_c2_hosp_miss,     #######
    x_f_ch_age11,         ##
    x_f_hood_verydissat,  #
    x_f_ad_nevmarr,       ####
    x_f_ad_race_other,    ####
    x_f_ch_age12,         #######
    x_f_ch_schplay,       ###
    x_f_ch_age13,         ###
    x_f_ch_age14,         #
    x_f_ad_race_black,    ####
    x_f_c2_lowbw_miss,    ##
    x_f_hh_size3,         ###
    x_f_c2_read,          #
    x_f_ch_male,          ####
    x_f_ch_specmed,       #
    x_f_ad_46_50,         #
    x_f_ch_age16,         ####
    x_f_c2_read_miss,     ###
    x_f_hh_size4,         #
    x_f_ad_41_45,         ##
    x_f_hh_afdc,          #
    x_f_hood_nofriend,    #
    x_f_hood_nbrkid,      #
    x_f_ad_male,          #
    x_f_hous_mov3tm       #
  )

# Rerun on just those predictors and build the 10-model array used for
# prediction.
b.fit <- bartMachine(
  X = xmat,
  y = datb$r_postturnout,
  num_trees = 50
)
b.fit.arr <- bartMachineArr(b.fit, R = 10)

# Predict on modified copies of the predictor matrix and take differences.
#
# Uses the global `xmat` and the global model array `b.fit.arr`. For each
# level (0/1) of `var`, every row of `xmat` is set to that level and
# predicted under control (treat = 0) and under treatment (treat = 1); the
# subgroup effect is the mean of treated minus control predictions.
#
# var: name (string) of a column of `xmat` to toggle between 0 and 1.
# Returns a named character vector: varname, ate_1, ate_0, diff
# (the numeric values are coerced to character by c()).
hte_cal <- function(var){

  # One counterfactual prediction: overwrite treat and `var` on a copy of
  # xmat, predict, and print the progress label.
  predict_scenario <- function(treat_value, attr_value, label) {
    scenario <- xmat
    scenario[, "treat"] <- treat_value
    scenario[, var] <- attr_value
    yhat <- predict_bartMachineArr(b.fit.arr, new_data = scenario)
    print(label)
    yhat
  }

  # effect when the attribute is 0
  ctrl_absent <- predict_scenario(0, 0, "b1 1 done")
  trt_absent <- predict_scenario(1, 0, "b2 1 done")
  ate_0 <- mean(trt_absent - ctrl_absent)

  # effect when the attribute is 1
  ctrl_present <- predict_scenario(0, 1, "b1 2 done")
  trt_present <- predict_scenario(1, 1, "b2 2 done")
  ate_1 <- mean(trt_present - ctrl_present)

  c("varname" = var, "ate_1" = ate_1, "ate_0" = ate_0,
    "diff" = ate_1 - ate_0)
}

# Repeat for all vars with >1 appearance and stack the results into `tab1`:
# one row per variable, columns varname / ate_1 / ate_0 / diff.
# "x_f_c2_read_miss" used to be listed twice, which recomputed the same
# (slow) estimate and drew the same point twice in the plot; it now appears
# once.
htevars <- c("x_f_ch_age19", "x_f_ad_le_35", "x_f_c2_read_miss", "x_f_ad_edgradhs_miss",
             "x_f_c2_hosp", "x_f_site_balt", "x_f_ch_age10", "x_f_c2_hosp_miss",
             "x_f_ch_age11", "x_f_ad_nevmarr", "x_f_ad_race_other", "x_f_ch_age12",
             "x_f_ch_schplay", "x_f_ch_age13", "x_f_ad_race_black", "x_f_c2_lowbw_miss",
             "x_f_hh_size3", "x_f_ch_male", "x_f_ch_age16", "x_f_ad_41_45")

# The previous loop ended each iteration with paste("done", var). `var` is
# not defined at top level, so it resolved to the stats::var function, which
# paste() cannot coerce to character; the loop stopped with an error after
# the first iteration. Progress is now reported with message() using the
# actual variable name, and the table is built in one rbind instead of being
# grown row by row.
tab1 <- do.call(rbind, lapply(htevars, function(v) {
  est <- hte_cal(v)
  message("done ", v)
  est
}))

# Plot the subgroup effects for each candidate variable:
# black point = effect when the attribute is 1 (ate_1),
# gray point  = effect when the attribute is 0 (ate_0).
# Variable codes are replaced by readable labels; unlisted codes are left
# as they are.
as.data.frame(tab1) %>%
  mutate(
    varname = recode(
      varname,
      x_f_site_balt = "Site: Baltimore",
      x_f_ch_schplay = "Play Sport",
      x_f_hh_size3 = "3 in household",
      x_f_ch_male = "Sex: Male",
      x_f_ch_age19 = "Age: 19",
      x_f_ch_age16 = "Age: 16",
      x_f_ch_age13 = "Age: 13",
      x_f_ch_age12 = "Age: 12",
      x_f_ch_age11 = "Age: 11",
      x_f_ch_age10 = "Age: 10",
      x_f_c2_read_miss = "Missing: read to child",
      x_f_c2_lowbw_miss = "Missing: low birthweight",
      x_f_c2_hosp_miss = "Missing: hospital stay",
      x_f_c2_hosp = "Stayed in hospital as infant",
      x_f_ad_race_other = "Adult race: other",
      x_f_ad_race_black = "Adult race: black",
      x_f_ad_nevmarr = "Adult: never married",
      x_f_ad_le_35 = "Adult age: <35",
      x_f_ad_edgradhs_miss = "Missing: adult education",
      x_f_ad_41_45 = "Adult age: 41-45"
    )
  ) %>%
  ggplot(aes(y = varname)) +
  # the estimates are stored as character, so convert before plotting
  geom_point(aes(x = as.numeric(ate_1))) +
  geom_point(aes(x = as.numeric(ate_0)), color = "gray") +
  geom_vline(xintercept = 0) +
  theme_bw() +
  labs(x = "CATE", y = "Variable") +
  theme(text = element_text(size = 20))


# repeat process for turnout after voter registration

# Young-kid sample restricted to rows with a non-missing post-registration
# turnout; experimental vs. control only, treatment coded 1/0.
datb <- dat %>%
  filter(
    ra_group_factor != "section 8",
    age_group == "young_kid",
    !is.na(r_postregturnout)
  ) %>%
  mutate(
    treat = case_when(
      ra_group_factor == "experimental" ~ 1,
      ra_group_factor == "control" ~ 0
    )
  )

# Matrix of possible moderators (same candidate set as for post-treatment
# turnout).
xmat <- datb %>%
  dplyr::select(
    treat,
    x_f_ad_edged, x_f_ad_edgradhs, x_f_ad_edgradhs_miss, x_f_ad_edinsch,
    x_f_ad_36_40, x_f_ad_41_45, x_f_ad_46_50,
    x_f_ad_ethn_hisp, x_f_ad_le_35,
    x_f_ad_male, x_f_ad_nevmarr, x_f_ad_parentu18,
    x_f_ad_race_black, x_f_ad_race_other, x_f_ad_working,
    x_f_hh_afdc, x_f_hh_car, x_f_hh_disabl, x_f_hh_noteens,
    x_f_hh_size2, x_f_hh_size3, x_f_hh_size4, x_f_hh_victim,
    x_f_hood_5y, x_f_hood_chat, x_f_hood_nbrkid, x_f_hood_nofamily,
    x_f_hood_nofriend, x_f_hood_unsafenit, x_f_hood_verydissat,
    x_f_hous_fndapt, x_f_hous_mov3tm, x_f_hous_movdrgs,
    x_f_hous_movschl, x_f_hous_sec8bef,
    x_f_site_balt, x_f_site_bos, x_f_site_chi, x_f_site_la,
    x_f_c2_hosp, x_f_c2_hosp_miss, x_f_c2_lowbw, x_f_c2_lowbw_miss,
    x_f_c2_read, x_f_c2_read_miss,
    x_f_ch_age10, x_f_ch_age11, x_f_ch_age12, x_f_ch_age13,
    x_f_ch_age14, x_f_ch_age15, x_f_ch_age16, x_f_ch_age17,
    x_f_ch_age18, x_f_ch_age19, x_f_ch_age20,
    x_f_ch_male, x_f_ch_schplay, x_f_ch_schplay_miss, x_f_ch_specmed
  )

# Fit BART on post-registration turnout, build 10 model runs, and inspect the
# fit of the first one.
b.fit <- bartMachine(
  X = xmat,
  y = datb$r_postregturnout,
  num_trees = 50
)
b.fit.arr <- bartMachineArr(b.fit, R = 10)
summary(b.fit.arr[[1]])
plot_y_vs_yhat(b.fit.arr[[1]], prediction_intervals = TRUE)

# Variable importance for each run; note variables appearing in a top 10.
for (run in seq_len(10)) {
  investigate_var_importance(b.fit.arr[[run]], num_replicates_for_avg = 10)
}

# Reduced moderator set: treatment plus the top-10 predictors from each model
# run. Each "#" after a name marks one model in which it was a top-10
# predictor (some entries carry no tally).
# NOTE(review): x_f_ch_age11 and x_f_ch_age12 are each listed twice;
# select() keeps a single copy of each column.
xmat <- datb %>%
  dplyr::select(
    treat,
    x_f_ad_edinsch,      ##
    x_f_hood_unsafenit,  #######
    x_f_ch_male,         #######
    x_f_ch_specmed,      ##
    x_f_ad_41_45,        ###
    x_f_c2_hosp_miss,    #####
    x_f_c2_lowbw,        #####
    x_f_ad_46_50,        ########
    x_f_ch_age11,        #
    x_f_c2_read,         #
    x_f_ch_age12,        #
    x_f_c2_lowbw_miss,   ####
    x_f_hood_nofriend,   #
    x_f_hous_mov3tm,     ######
    x_f_ad_le_35,        ##
    x_f_ad_male,         #####
    x_f_ch_age16,        #####
    x_f_ch_age12,
    x_f_ch_age10,        ##
    x_f_ch_age11,        ###
    x_f_ch_schplay_miss,
    x_f_ch_age15,        #
    x_f_ad_race_other,   #
    x_f_ad_edgradhs_miss,
    x_f_ad_working,
    x_f_ch_age19,
    x_f_ad_ethn_hisp,
    x_f_hh_afdc
  )

# Rerun on just those predictors and build the 10-model array used for
# prediction.
b.fit <- bartMachine(
  X = xmat,
  y = datb$r_postregturnout,
  num_trees = 50
)
b.fit.arr <- bartMachineArr(b.fit, R = 10)
# predict on new matrices, find diffs:

# All vars with >1 appearance (plus "treat", which is not a moderator and is
# excluded from the subgroup calculations below).
htevars <- c("treat",
             "x_f_ad_edinsch", "x_f_hood_unsafenit", "x_f_ch_male", "x_f_ch_specmed",
             "x_f_ad_41_45", "x_f_c2_hosp_miss", "x_f_c2_lowbw", "x_f_ad_46_50",
             "x_f_ch_age11", "x_f_c2_read", "x_f_ch_age12", "x_f_c2_lowbw_miss",
             "x_f_hood_nofriend", "x_f_hous_mov3tm", "x_f_ad_le_35","x_f_ad_male",
             "x_f_ch_age16", "x_f_ch_age10", "x_f_ch_age11", "x_f_ch_age15",
             "x_f_ad_race_other")

# Stack one hte_cal() result per moderator into `tab2` (columns varname /
# ate_1 / ate_0 / diff). setdiff() drops "treat" by name and removes the
# repeated "x_f_ch_age11", keeping first-occurrence order. This replaces the
# hard-coded first call plus a loop starting at index 3, which depended on
# "treat" staying in position 1 and estimated x_f_ch_age11 twice.
tab2 <- do.call(rbind, lapply(setdiff(htevars, "treat"), hte_cal))

# Plot the subgroup effects for post-registration turnout:
# black point = effect when the attribute is 1 (ate_1),
# gray point  = effect when the attribute is 0 (ate_0).
# Variable codes are replaced by readable labels; unlisted codes are left
# as they are.
as.data.frame(tab2) %>%
  mutate(
    varname = recode(
      varname,
      x_f_hous_mov3tm = "Moved >3 times in 5 yrs",
      x_f_hood_nofriend = "No friends in nbhd",
      x_f_ch_specmed = "Req. special medicine",
      x_f_ch_male = "Sex: Male",
      x_f_ch_age16 = "Age: 16",
      x_f_ch_age15 = "Age: 15",
      x_f_ch_age12 = "Age: 12",
      x_f_ch_age11 = "Age: 11",
      x_f_ch_age10 = "Age: 10",
      x_f_c2_read = "Read to child",
      x_f_c2_lowbw_miss = "Missing: low birthweight",
      x_f_c2_lowbw = "Low birthweight",
      x_f_c2_hosp_miss = "Missing: hospitalized",
      x_f_ad_race_other = "Adult race: other",
      x_f_ad_male = "Adult sex: male",
      x_f_ad_le_35 = "Adult age: <35",
      x_f_ad_edinsch = "Adult educ: in school",
      x_f_ad_46_50 = "Adult age: 46-50",
      x_f_ad_41_45 = "Adult age: 41-45",
      x_f_hood_unsafenit = "Neighborhood unsafe"
    )
  ) %>%
  ggplot(aes(y = varname)) +
  # the estimates are stored as character, so convert before plotting
  geom_point(aes(x = as.numeric(ate_1))) +
  geom_point(aes(x = as.numeric(ate_0)), color = "gray") +
  geom_vline(xintercept = 0) +
  theme_bw() +
  labs(x = "CATE", y = "Variable") +
  theme(text = element_text(size = 20))

